{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7765UFHoyGx6"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KsOkK8O69PyT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZS8z-_KeywY9"
      },
      "source": [
        "# TF Lattice Canned Estimator"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r61fkA2i9Y3_"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/lattice/tutorials/canned_estimators\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">TensorFlow.org에서 보기</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ko/lattice/tutorials/canned_estimators.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colab에서 실행하기</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ko/lattice/tutorials/canned_estimators.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub에서소스 보기</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ko/lattice/tutorials/canned_estimators.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">노트북 다운로드하기</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WCpl-9WDVq9d"
      },
      "source": [
        "## 개요\n",
        "\n",
        "준비된 estimator는 일반적인 사용 사례를 위해 TFL 모델을 훈련하는 빠르고 쉬운 방법입니다. 이 가이드에서는 TFL canned estimator를 만드는 데 필요한 단계를 설명합니다."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x769lI12IZXB"
      },
      "source": [
        "## 설정"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fbBVAR6UeRN5"
      },
      "source": [
        "TF Lattice 패키지 설치하기"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bpXjJKpSd3j4"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install tensorflow-lattice"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jSVl9SHTeSGX"
      },
      "source": [
        "필수 패키지 가져오기"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "FbZDk8bIx8ig"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import copy\n",
        "import logging\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sys\n",
        "import tensorflow_lattice as tfl\n",
        "from tensorflow import feature_column as fc\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "svPuM6QNxlrH"
      },
      "source": [
        "UCI Statlog(Heart) 데이터세트 다운로드하기"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "j-k1qTR_yvBl"
      },
      "outputs": [],
      "source": [
        "csv_file = tf.keras.utils.get_file(\n",
        "    'heart.csv', 'http://storage.googleapis.com/applied-dl/heart.csv')\n",
        "df = pd.read_csv(csv_file)\n",
        "target = df.pop('target')\n",
        "train_size = int(len(df) * 0.8)\n",
        "train_x = df[:train_size]\n",
        "train_y = target[:train_size]\n",
        "test_x = df[train_size:]\n",
        "test_y = target[train_size:]\n",
        "df.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nKkAw12SxvGG"
      },
      "source": [
        "이 가이드에서 훈련에 사용되는 기본값 설정하기"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "1T6GFI9F6mcG"
      },
      "outputs": [],
      "source": [
        "LEARNING_RATE = 0.01\n",
        "BATCH_SIZE = 128\n",
        "NUM_EPOCHS = 500\n",
        "PREFITTING_NUM_EPOCHS = 10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0TGfzhPHzpix"
      },
      "source": [
        "## 특성 열\n",
        "\n",
        "다른 TF estimator와 마찬가지로 데이터는 일반적으로 input_fn을 통해 estimator로 전달되어야 하며 [FeatureColumns](https://www.tensorflow.org/guide/feature_columns)를 사용하여 구문 분석됩니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DCIUz8apzs0l"
      },
      "outputs": [],
      "source": [
        "# Feature columns.\n",
        "# - age\n",
        "# - sex\n",
        "# - cp        chest pain type (4 values)\n",
        "# - trestbps  resting blood pressure\n",
        "# - chol      serum cholestoral in mg/dl\n",
        "# - fbs       fasting blood sugar > 120 mg/dl\n",
        "# - restecg   resting electrocardiographic results (values 0,1,2)\n",
        "# - thalach   maximum heart rate achieved\n",
        "# - exang     exercise induced angina\n",
        "# - oldpeak   ST depression induced by exercise relative to rest\n",
        "# - slope     the slope of the peak exercise ST segment\n",
        "# - ca        number of major vessels (0-3) colored by flourosopy\n",
        "# - thal      3 = normal; 6 = fixed defect; 7 = reversable defect\n",
        "feature_columns = [\n",
        "    fc.numeric_column('age', default_value=-1),\n",
        "    fc.categorical_column_with_vocabulary_list('sex', [0, 1]),\n",
        "    fc.numeric_column('cp'),\n",
        "    fc.numeric_column('trestbps', default_value=-1),\n",
        "    fc.numeric_column('chol'),\n",
        "    fc.categorical_column_with_vocabulary_list('fbs', [0, 1]),\n",
        "    fc.categorical_column_with_vocabulary_list('restecg', [0, 1, 2]),\n",
        "    fc.numeric_column('thalach'),\n",
        "    fc.categorical_column_with_vocabulary_list('exang', [0, 1]),\n",
        "    fc.numeric_column('oldpeak'),\n",
        "    fc.categorical_column_with_vocabulary_list('slope', [0, 1, 2]),\n",
        "    fc.numeric_column('ca'),\n",
        "    fc.categorical_column_with_vocabulary_list(\n",
        "        'thal', ['normal', 'fixed', 'reversible']),\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hEZstmtT2CA3"
      },
      "source": [
        "준비된 TFL estimator는 특성 열의 유형을 사용하여 사용할 보정 레이어 유형을 결정합니다. 숫자 특성 열에는 `tfl.layers.PWLCalibration`를, 범주형 특성 열에는 `tfl.layers.CategoricalCalibration` 레이어가 사용됩니다.\n",
        "\n",
        "범주형 특성 열은 임베딩 특성 열로 래핑되지 않고 estimator에 직접 공급됩니다."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H_LoW_9m5OFL"
      },
      "source": [
        "## input_fn 만들기\n",
        "\n",
        "다른 estimator와 마찬가지로 input_fn을 사용하여 훈련 및 평가를 위해 모델에 데이터를 공급할 수 있습니다. TFL estimator는 특성의 분위수를 자동으로 계산하고 이를 PWL 보정 레이어의 입력 키포인트로 사용할 수 있습니다. 이를 위해서는 훈련 input_fn과 유사하지만 단일 epoch 또는 데이터의 하위 샘플이 있는 `feature_analysis_input_fn`를 전달해야 합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lFVy1Efy5NKD"
      },
      "outputs": [],
      "source": [
        "train_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=train_x,\n",
        "    y=train_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=NUM_EPOCHS,\n",
        "    num_threads=1)\n",
        "\n",
        "# feature_analysis_input_fn is used to collect statistics about the input.\n",
        "feature_analysis_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=train_x,\n",
        "    y=train_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    # Note that we only need one pass over the data.\n",
        "    num_epochs=1,\n",
        "    num_threads=1)\n",
        "\n",
        "test_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=test_x,\n",
        "    y=test_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=1,\n",
        "    num_threads=1)\n",
        "\n",
        "# Serving input fn is used to create saved models.\n",
        "serving_input_fn = (\n",
        "    tf.estimator.export.build_parsing_serving_input_receiver_fn(\n",
        "        feature_spec=fc.make_parse_example_spec(feature_columns)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uQlzREcm2Wbj"
      },
      "source": [
        "## 특성 구성\n",
        "\n",
        "특성 보정 및 특성별 구성은 `tfl.configs.FeatureConfig`를 사용하여 설정됩니다. 특성 구성에는 단조 제약 조건, 특성별 정규화(`tfl.configs.RegularizerConfig` 참조) 및 격자 모델에 대한 격자 크기가 포함됩니다.\n",
        "\n",
        "입력 특성에 대한 구성이 정의되지 않은 경우 `tfl.config.FeatureConfig`의 기본 구성이 사용됩니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vD0tNpiO3p9c"
      },
      "outputs": [],
      "source": [
        "# Feature configs are used to specify how each feature is calibrated and used.\n",
        "feature_configs = [\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='age',\n",
        "        lattice_size=3,\n",
        "        # By default, input keypoints of pwl are quantiles of the feature.\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_clip_max=100,\n",
        "        # Per feature regularization.\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='cp',\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        # Keypoints can be uniformly spaced.\n",
        "        pwl_calibration_input_keypoints='uniform',\n",
        "        monotonicity='increasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='chol',\n",
        "        # Explicit input keypoint initialization.\n",
        "        pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\n",
        "        monotonicity='increasing',\n",
        "        # Calibration can be forced to span the full output range by clamping.\n",
        "        pwl_calibration_clamp_min=True,\n",
        "        pwl_calibration_clamp_max=True,\n",
        "        # Per feature regularization.\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='fbs',\n",
        "        # Partial monotonicity: output(0) <= output(1)\n",
        "        monotonicity=[(0, 1)],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='trestbps',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        monotonicity='decreasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='thalach',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        monotonicity='decreasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='restecg',\n",
        "        # Partial monotonicity: output(0) <= output(1), output(0) <= output(2)\n",
        "        monotonicity=[(0, 1), (0, 2)],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='exang',\n",
        "        # Partial monotonicity: output(0) <= output(1)\n",
        "        monotonicity=[(0, 1)],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='oldpeak',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        monotonicity='increasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='slope',\n",
        "        # Partial monotonicity: output(0) <= output(1), output(1) <= output(2)\n",
        "        monotonicity=[(0, 1), (1, 2)],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='ca',\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        monotonicity='increasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='thal',\n",
        "        # Partial monotonicity:\n",
        "        # output(normal) <= output(fixed)\n",
        "        # output(normal) <= output(reversible)        \n",
        "        monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],\n",
        "    ),\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LKBULveZ4mr3"
      },
      "source": [
        "## 보정된 선형 모델\n",
        "\n",
        "준비된 TFL estimator를 구성하려면 `tfl.configs`에서 모델 구성을 갖추세요. 보정된 선형 모델은 `tfl.configs.CalibratedLinearConfig`를 사용하여 구성됩니다. 입력 특성에 부분 선형 및 범주형 보정을 적용한 다음 선형 조합 및 선택적 출력 부분 선형 보정을 적용합니다. 출력 보정을 사용하거나 출력 경계가 지정된 경우 선형 레이어는 보정된 입력에 가중치 평균을 적용합니다.\n",
        "\n",
        "이 예제에서는 처음 5개 특성에 대해 보정된 선형 모델을 만듭니다. `tfl.visualization`을 사용하여 보정 플롯으로 모델 그래프를 플롯합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "diRRozio4sAL"
      },
      "outputs": [],
      "source": [
        "# Model config defines the model structure for the estimator.\n",
        "model_config = tfl.configs.CalibratedLinearConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    use_bias=True,\n",
        "    output_calibration=True,\n",
        "    regularizer_configs=[\n",
        "        # Regularizer for the output calibrator.\n",
        "        tfl.configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),\n",
        "    ])\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns[:5],\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Calibrated linear test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zWzPM2_p977t"
      },
      "source": [
        "## 보정된 격자 모델\n",
        "\n",
        "보정된 격자 모델은 `tfl.configs.CalibratedLatticeConfig`를 사용하여 구성됩니다. 보정된 격자 모델은 입력 특성에 구간별 선형 및 범주형 보정을 적용한 다음 격자 모델 및 선택적 출력 구간별 선형 보정을 적용합니다.\n",
        "\n",
        "이 예제에서는 처음 5개의 특성에 대해 보정된 격자 모델을 만듭니다.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C6EvVpKW4BbC"
      },
      "outputs": [],
      "source": [
        "# This is calibrated lattice model: Inputs are calibrated, then combined\n",
        "# non-linearly using a lattice layer.\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    regularizer_configs=[\n",
        "        # Torsion regularizer applied to the lattice to make it more linear.\n",
        "        tfl.configs.RegularizerConfig(name='torsion', l2=1e-4),\n",
        "        # Globally defined calibration regularizer is applied to all features.\n",
        "        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n",
        "    ])\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns[:5],\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Calibrated lattice test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9494K_ZBKFcm"
      },
      "source": [
        "## 보정된 격자 앙상블\n",
        "\n",
        "특성 수가 많으면 앙상블 모델을 사용할 수 있습니다. 이 모델은 특성의 하위 집합에 대해 여러 개의 작은 격자를 만들고, 하나의 거대한 격자를 만드는 대신 출력을 평균화합니다. 앙상블 격자 모델은 `tfl.configs.CalibratedLatticeEnsembleConfig`를 사용하여 구성됩니다. 보정된 격자 앙상블 모델은 입력 특성에 구간별 선형 및 범주형 보정을 적용한 다음 격자 모델 앙상블과 선택적 출력 구간별 선형 보정을 적용합니다.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KjrzziMFKuCB"
      },
      "source": [
        "### 무작위 격자 앙상블\n",
        "\n",
        "다음 모델 구성은 각 격자에 대해 무작위의 특성 하위 집합을 사용합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YBSS7dLjKExq"
      },
      "outputs": [],
      "source": [
        "# This is random lattice ensemble model with separate calibration:\n",
        "# model output is the average output of separatly calibrated lattices.\n",
        "model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3)\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Random ensemble test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph, calibrator_dpi=15)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7uyO8s97FGJM"
      },
      "source": [
        "### RTL 레이어 무작위 격자 앙상블\n",
        "\n",
        "다음 모델 구성은 각 격자에 대해 무작위의 특성 하위 집합을 사용하는 `tfl.layers.RTL` 레이어를 사용합니다. `tfl.layers.RTL`은 단조 제약 조건만 지원하며 모든 특성에 대해 동일한 격자 크기를 가져야 하고 특성별 정규화가 없어야 합니다. `tfl.layers.RTL` 레이어를 사용하면 별도의 `tfl.layers.Lattice` 인스턴스를 사용하는 것보다 훨씬 더 큰 앙상블로 확장할 수 있습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8v7dKg-FF7iz"
      },
      "outputs": [],
      "source": [
        "# Make sure our feature configs have the same lattice size, no per-feature\n",
        "# regularization, and only monotonicity constraints.\n",
        "rtl_layer_feature_configs = copy.deepcopy(feature_configs)\n",
        "for feature_config in rtl_layer_feature_configs:\n",
        "  feature_config.lattice_size = 2\n",
        "  feature_config.unimodality = 'none'\n",
        "  feature_config.reflects_trust_in = None\n",
        "  feature_config.dominates = None\n",
        "  feature_config.regularizer_configs = None\n",
        "# This is RTL layer ensemble model with separate calibration:\n",
        "# model output is the average output of separately calibrated lattices.\n",
        "model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    lattices='rtl_layer',\n",
        "    feature_configs=rtl_layer_feature_configs,\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3)\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Random ensemble test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph, calibrator_dpi=15)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LSXEaYAULRvf"
      },
      "source": [
        "###  Crystals 격자 앙상블\n",
        "\n",
        "TFL은 또한 [Crystals](https://papers.nips.cc/paper/6377-fast-and-flexible-monotonic-functions-with-ensembles-of-lattices)라고 하는 휴리스틱 특성 배열 알고리즘을 제공합니다. Crystals 알고리즘은 먼저 쌍별 특성 상호 작용을 예측하는 *사전 적합 모델*을 훈련합니다. 그런 다음 비 선형 상호 작용이 더 많은 특성이 동일한 격자에 있도록 최종 앙상블을 정렬합니다.\n",
        "\n",
        "Crystals 모델의 경우 위에서 설명한 대로 사전 적합 모델을 훈련하는 데 사용되는 `prefitting_input_fn`도 제공해야 합니다. 사전 적합 모델은 완전하게 훈련될 필요가 없기에 몇 번의 epoch면 충분합니다.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FjQKh9saMaFu"
      },
      "outputs": [],
      "source": [
        "prefitting_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=train_x,\n",
        "    y=train_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=PREFITTING_NUM_EPOCHS,\n",
        "    num_threads=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fVnZpwX8MtPi"
      },
      "source": [
        "그런 다음 모델 구성에서 `lattice='crystals'` 를 설정하여 Crystal 모델을 만들 수 있습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f4awRMDe-eMv"
      },
      "outputs": [],
      "source": [
        "# This is Crystals ensemble model with separate calibration: model output is\n",
        "# the average output of separatly calibrated lattices.\n",
        "model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    lattices='crystals',\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3)\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    # prefitting_input_fn is required to train the prefitting model.\n",
        "    prefitting_input_fn=prefitting_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    prefitting_optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Crystals ensemble test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph, calibrator_dpi=15)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Isb2vyLAVBM1"
      },
      "source": [
        "`tfl.visualization` 모듈을 사용하여 더 자세한 정보로 특성 calibrator를 플롯할 수 있습니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DJPaREuWS2sg"
      },
      "outputs": [],
      "source": [
        "_ = tfl.visualization.plot_feature_calibrator(model_graph, \"age\")\n",
        "_ = tfl.visualization.plot_feature_calibrator(model_graph, \"restecg\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "canned_estimators.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
